#!/usr/bin/env pmpython
"""Test pmdaopenmetrics via exposing fake endpoints -*- python -*- """
#
# Copyright (C) 2017-2019 Red Hat.
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
# for more details.
#

import time
import re
from six.moves import queue
import threading
import socket
from six.moves import BaseHTTPServer
import argparse


# Endpoint number -> count of responses served so far, for every endpoint
# that is currently being answered.  Shared by all server threads.
activeEndpoint = {}


def write_endpoint_metadata(args=None, endpoint=None):
    """Write the URL of one fake endpoint to ``<output>/source<endpoint>.url``.

    args     -- parsed options; reads ``args.addr[1]`` (port), ``args.url``
                (URL prefix) and ``args.output`` (target directory)
    endpoint -- endpoint number, appended to both the URL and the file name

    The file holds a single line of the form
    ``http://localhost:<port>/<url><endpoint>``.
    """
    endpoint_url = "http://{}:{}/{}{}".format("localhost", args.addr[1], args.url, endpoint)
    target = '{}/{}'.format(str(args.output), "source" + str(endpoint) + ".url")
    # Context manager so the handle is closed even if the write fails.
    with open(target, 'w') as handle:
        handle.write("{}\n".format(endpoint_url))


class FakeEndpoint(BaseHTTPServer.BaseHTTPRequestHandler):
    """HTTP handler that serves synthetic OpenMetrics text.

    A request path is expected to contain ``server.args.url`` followed by an
    endpoint number.  Only endpoint numbers currently present in the
    module-level ``activeEndpoint`` mapping are answered with data; any other
    request receives a 404.
    """

    def sample_gauge_data(self, iteration=None, instances=None, metrics=None, endpointNum=None):
        """Return the exposition text for one gauge metric.

        iteration   -- multiplier applied to every sample value
        instances   -- number of samples to emit, labelled bar="0", "1", ...
        metrics     -- unused; accepted so all generators share a signature
        endpointNum -- number used as the zero-padded metric name suffix
        """
        gauge_value = 12.1
        instance_scale = 1
        # NOTE(review): the type line is spelled "# Type" here while the other
        # generators emit "# TYPE".  Kept byte-for-byte because consumers of
        # this fixture may depend on it -- confirm whether it is intentional.
        lines = [
            "# HELP sample_gauge{:04d} Sample gauge metric.\n".format(endpointNum),
            "# HELP sample_gauge{:04d} instance scale {}, value scale {}\n".format(endpointNum, instance_scale, gauge_value),
            "# Type sample_gauge{:04d} gauge\n".format(endpointNum),
        ]
        for i in range(instances):
            lines.append("sample_gauge{:04d}{{bar=\"{}\"}} {:.1f}\n".format(
                endpointNum, (i * instance_scale), (int(iteration) * i * float(gauge_value))))
        return "".join(lines)

    def sample_counter_data(self, iteration=None, instances=None, metrics=None, endpointNum=None):
        """Return the exposition text for one counter metric.

        iteration   -- multiplier applied to every sample value
        instances   -- number of samples to emit, labelled baz="0.0", "0.7", ...
        metrics     -- unused; accepted so all generators share a signature
        endpointNum -- number used as the zero-padded metric name suffix
        """
        count_value = 1.70205394e+08
        instance_scale = 0.7
        lines = [
            "# HELP sample_counter{:04d} Sample counter metric.\n".format(endpointNum),
            "# HELP sample_counter{:04d} instance scale {} value scale {}\n".format(endpointNum, instance_scale, count_value),
            "# TYPE sample_counter{:04d} counter\n".format(endpointNum),
        ]
        for i in range(instances):
            lines.append("sample_counter{:04d}{{baz=\"{:.1f}\"}} {:.8e}\n".format(
                endpointNum, (i * instance_scale), (int(iteration) * i * count_value)))
        return "".join(lines)

    def sample_summary_data(self, iteration=None, instances=None, metrics=None, endpointNum=None):
        """Return the exposition text for one summary metric.

        Emits ``instances`` quantile samples followed by the ``_sum`` and
        ``_count`` samples, all scaled by ``iteration``.

        iteration   -- multiplier applied to every sample value
        instances   -- number of quantile samples to emit
        metrics     -- unused; accepted so all generators share a signature
        endpointNum -- number used as the zero-padded metric name suffix
        """
        summary_value_q0 = 0.000159623
        summary_value_sum = 1.3912067700000001
        summary_value_count = 1818
        instance_scale = 0.25
        lines = [
            "# HELP sample_summary{:04d} Sample summary metric has instances.\n".format(endpointNum),
            "# HELP sample_summary{:04d} instance scale {} value scale {}\n".format(endpointNum, instance_scale, summary_value_q0),
            "# TYPE sample_summary{:04d} summary\n".format(endpointNum),
        ]
        for i in range(instances):
            lines.append("sample_summary{:04d}{{quantile=\"{}\"}} {:.9f}\n".format(
                endpointNum, (i * instance_scale), (int(iteration) * i * float(summary_value_q0))))
        lines.append("sample_summary{:04d}_sum {:.16f}\n".format(
            endpointNum, (int(iteration) * float(summary_value_sum))))
        lines.append("sample_summary{:04d}_count {}\n".format(
            endpointNum, (int(iteration) * int(summary_value_count))))
        return "".join(lines)

    def sample_histogram_data(self, iteration=None, instances=None, metrics=None, endpointNum=None):
        """Return the exposition text for one histogram metric.

        Emits ``instances`` bucket samples (label ``le``) followed by the
        ``_sum`` and ``_count`` samples.

        iteration   -- multiplier applied to every bucket value
        instances   -- number of bucket samples to emit
        metrics     -- unused; accepted so all generators share a signature
        endpointNum -- number used as the zero-padded metric name suffix
        """
        histogram_value_l1 = 5945
        histogram_value_sum = 45157
        histogram_value_count = 18135
        instance_scale = 2
        lines = [
            "# HELP sample_histogram{:04d} Sample histogram metric has instances.\n".format(endpointNum),
            "# HELP sample_histogram{:04d} instance scale {} value scale {}\n".format(endpointNum, instance_scale, histogram_value_l1),
            "# TYPE sample_histogram{:04d} histogram\n".format(endpointNum),
        ]
        for i in range(instances):
            lines.append("sample_histogram{:04d}{{le=\"{}\"}} {:d}\n".format(
                endpointNum, (i * instance_scale), (int(iteration) * i * histogram_value_l1)))
        # NOTE(review): _sum/_count are scaled by `instances`, whereas the
        # summary generator scales them by `iteration`.  Left unchanged so the
        # generated data stays identical -- confirm whether it is intentional.
        lines.append("sample_histogram{:04d}_sum {:d}\n".format(
            endpointNum, (int(instances) * int(histogram_value_sum))))
        lines.append("sample_histogram{:04d}_count {:d}\n".format(
            endpointNum, (int(instances) * int(histogram_value_count))))
        return "".join(lines)

    def format_openmetrics_output(self, metrics=None, iteration=None, instances=None, error=None, endpointNum=None):
        """Return the full response body for one request.

        Cycles gauge -> counter -> summary -> histogram while a budget that
        starts at ``metrics`` lasts.  A gauge or counter costs 1, a summary or
        histogram costs ``instances``.  Each generated metric is named after
        the budget already spent and uses the next ``iteration`` value.

        metrics     -- metric budget for this response
        iteration   -- value scale handed to the first generator
        instances   -- instances per metric
        error       -- when not None, only the first ``1/error`` of the text
                       is returned, producing a deliberately truncated body
        endpointNum -- unused
        """
        remaining = metrics
        chunks = []
        while remaining >= 0:
            chunks.append(self.sample_gauge_data(iteration, instances, remaining, (metrics - remaining)))
            remaining -= 1
            iteration += 1
            if remaining <= 0:
                break
            chunks.append(self.sample_counter_data(iteration, instances, remaining, (metrics - remaining)))
            remaining -= 1
            iteration += 1
            if remaining <= 0:
                break
            chunks.append(self.sample_summary_data(iteration, instances, remaining, (metrics - remaining)))
            # A summary is charged one unit per instance.
            remaining -= instances
            iteration += 1
            if remaining <= 0:
                break
            chunks.append(self.sample_histogram_data(iteration, instances, remaining, (metrics - remaining)))
            # A histogram is charged one unit per instance.
            remaining -= instances
            iteration += 1
        endpoint_string = "".join(chunks)
        if error is not None:
            return endpoint_string[:int(len(endpoint_string) / int(error))]
        return endpoint_string

    def _requested_endpoint(self):
        """Return the endpoint number encoded in the request path, or None.

        The configured URL is matched literally (escaped), so characters such
        as '?' or '.' in ``--url`` match themselves, exactly as they appear in
        the URL written to the endpoint metadata file.
        """
        pieces = re.split(re.escape(str(self.server.args.url)), self.path)
        if len(pieces) < 2:
            return None
        try:
            return int(pieces[1])
        except ValueError:
            # Anything but a bare number after the URL is not one of ours.
            return None

    def do_GET(self):
        """Serve one sample for the requested endpoint and advance its count.

        Sleeps ``args.delay`` seconds first.  An unknown, malformed or
        inactive endpoint gets a 404.  Otherwise the body is generated from
        the endpoint's current response count; when ``args.error`` is
        non-zero, every response whose non-zero count is a multiple of it is
        truncated.  After responding, the count is incremented under the
        server lock.  Once it reaches ``args.limit`` the endpoint is retired
        and the next queued endpoint, if any, is activated and its metadata
        file written.
        """
        args = self.server.args
        time.sleep(float(args.delay))

        endpointNum = self._requested_endpoint()
        # Take one snapshot of the count: the endpoint may be retired by
        # another server thread at any time.
        served = None if endpointNum is None else activeEndpoint.get(endpointNum)
        if served is None:
            self.send_error(404)
            return

        error_every = int(args.error)
        if error_every and served and not served % error_every:
            error = error_every
        else:
            error = None

        self.send_response(200)
        self.end_headers()
        self.wfile.write(self.format_openmetrics_output(
            int(args.metrics), endpointNum * served, int(args.instances), error, endpointNum).encode())

        with self.server.lock:
            if endpointNum not in activeEndpoint:
                # Retired by a concurrent request while we were responding.
                return
            activeEndpoint[endpointNum] += 1
            if activeEndpoint[endpointNum] < int(args.limit):
                return
            # Limit reached: retire this endpoint and bring up the next one.
            del activeEndpoint[endpointNum]
            self.server.pqueue.task_done()
            try:
                _next = self.server.pqueue.get_nowait()
            except queue.Empty:
                return
            activeEndpoint[int(_next)] = 0
            write_endpoint_metadata(args, _next)

class PrometheusEndpoint(threading.Thread):
    """Daemon thread running one HTTP server on the shared listening socket.

    On start-up the thread takes one endpoint number from the queue (if any
    is left), marks it active and writes its metadata file, then serves
    requests forever using the already-bound socket from ``args.sock``.

    pqueue -- queue of endpoint numbers still waiting to be activated
    args   -- parsed options, extended by the caller with ``addr`` and ``sock``
    """

    # A single lock for every instance.  All server threads update the same
    # module-level activeEndpoint mapping, so a lock created per thread would
    # not provide any mutual exclusion between them.
    lock = threading.Lock()

    def __init__(self, pqueue=None, args=None):
        threading.Thread.__init__(self)
        self.daemon = True
        self.pqueue = pqueue
        self.args = args
        self.server = None
        # The thread starts itself; constructing the object is enough.
        self.start()

    def run(self):
        """Activate one queued endpoint, then serve requests forever."""
        try:
            endpoint = self.pqueue.get_nowait()
            with self.lock:
                activeEndpoint[endpoint] = 0
                write_endpoint_metadata(self.args, endpoint)
        except queue.Empty:
            # More server threads than endpoints: just serve.
            pass

        finally:
            # bind_and_activate=False: the listening socket is created once by
            # the caller and shared by every server thread.
            httpd = BaseHTTPServer.HTTPServer(self.args.addr, FakeEndpoint, False)
            httpd.socket = self.args.sock
            httpd.pqueue = self.pqueue
            httpd.delay = self.args.delay
            httpd.args = self.args
            # NOTE(review): this neutralises httpd.server_bind, but the second
            # target is `self.server_close` on the thread object, not on httpd.
            httpd.server_bind = self.server_close = lambda self: None
            httpd.lock = self.lock
            self.server = httpd
            self.server.serve_forever()


def parsing():
    """Parse the command line and return the resulting argparse namespace.

    No option declares a type, so values supplied on the command line are
    strings while untouched options keep the defaults listed below.
    """
    # (flag, default, help text) for every supported option.
    option_table = (
        ('--endpoints', 2, 'number of openmetrics endpoints to start at once'),
        ('--metrics', 5, 'number of metrics per openmetrics endpoint'),
        ('--instances', 5, 'number of instances per metric'),
        ('--delay', 0, 'delay time (seconds) for each "slow" node'),
        ('--limit', 5, 'number of iterations/responses each endpoint limits itself to'),
        ('--total', 5, 'total number of endpoints to run'),
        ('--port', 10000, 'port to start fake endpoints on'),
        ('--error', 0, 'create errornous data in openmetrics endpoints'),
        ('--url', "foo&endpoint=", 'url for the endpoint, the endpoint number must be the last part'),
        ('--output', "/tmp", 'directory to create endpoint metadata in'),
    )
    cli = argparse.ArgumentParser(
        description='Setup a number of fake openmetrics endpoints for collection.')
    for flag, fallback, blurb in option_table:
        cli.add_argument(flag, default=fallback, help=blurb)
    return cli.parse_args()

if __name__ == '__main__':
    # Script entry point: queue up every endpoint number, open one listening
    # socket, start the server threads and wait until every endpoint has
    # served its quota.
    args = parsing()

    # Endpoint numbers 0..total-1 wait here until a server thread activates them.
    pendpointQueue = queue.Queue()
    for endpoint in range(int(args.total)):
        pendpointQueue.put(endpoint)

    # One socket, bound and listening here, is shared by all server threads.
    addr = ('', int(args.port))
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(addr)
    # NOTE(review): the port number is passed as the listen() backlog.
    sock.listen(int(args.port))
    args.addr = addr
    args.sock = sock

    # Each PrometheusEndpoint starts its own thread on construction.
    pendpoints = []
    for _ in range(int(args.endpoints)):
        pendpoints.append(PrometheusEndpoint(pendpointQueue, args))

    # Returns once task_done() has been called for every queued endpoint.
    pendpointQueue.join()
    sock.close()
